/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "GPRSProtocol.h"
#include <iomanip> // setw

namespace aiengine {

// Registers this dissector under the name "GPRS", carried over IPPROTO_UDP.
GPRSProtocol::GPRSProtocol()
	: Protocol("GPRS", IPPROTO_UDP) {
}

// Quick acceptance test for a packet that might carry GPRS.
// A packet is valid when it holds at least header_size bytes and at least one
// of the bits in mask 0x30 of the flags byte is set. (processFlow() reads the
// version from flags >> 5, so 0x20 is the lowest version bit. 0x10 is
// presumably the GTP Protocol Type bit — confirm against the header layout.)
// Updates total_valid_packets_ / total_invalid_packets_ and leaves header_
// pointing at the payload whenever the length test passes.
bool GPRSProtocol::check(const Packet &packet) {

	const int length = packet.getLength();

	if (length < header_size) {
		++total_invalid_packets_;
		return false;
	}

	setHeader(packet.getPayload());

	if ((header_->flags & 0x30) == 0) {
		++total_invalid_packets_;
		return false;
	}

	++total_valid_packets_;
	return true;
}

// Forwards the dynamic-allocation switch to the GPRSInfo cache.
void GPRSProtocol::setDynamicAllocatedMemory(bool value) {

	auto &cache = gprs_info_cache_;
	cache->setDynamicAllocatedMemory(value);
}

// Reports whether the GPRSInfo cache is in dynamic-allocation mode.
bool GPRSProtocol::isDynamicAllocatedMemory() const {

	const auto &cache = gprs_info_cache_;
	return cache->isDynamicAllocatedMemory();
}

// Memory currently in use: the protocol object itself plus whatever the
// GPRSInfo cache reports as in use.
uint64_t GPRSProtocol::getCurrentUseMemory() const {

	return static_cast<uint64_t>(sizeof(GPRSProtocol)) + gprs_info_cache_->getCurrentUseMemory();
}

// Memory allocated: the protocol object itself plus everything the
// GPRSInfo cache has allocated (used or not).
uint64_t GPRSProtocol::getAllocatedMemory() const {

	return static_cast<uint64_t>(sizeof(GPRSProtocol)) + gprs_info_cache_->getAllocatedMemory();
}

uint64_t GPRSProtocol::getTotalAllocatedMemory() const {

        return getAllocatedMemory();
}

// Cache misses are the GPRSInfo cache's failure count.
uint32_t GPRSProtocol::getTotalCacheMisses() const {

	const auto &cache = gprs_info_cache_;
	return cache->getTotalFails();
}

// Asks the GPRSInfo cache to create `value` additional entries.
void GPRSProtocol::increaseAllocatedMemory(int value) {

	auto &cache = gprs_info_cache_;
	cache->create(value);
}

// Asks the GPRSInfo cache to destroy `value` entries.
void GPRSProtocol::decreaseAllocatedMemory(int value) {

	auto &cache = gprs_info_cache_;
	cache->destroy(value);
}

void GPRSProtocol::releaseCache() {

        if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
                auto ft = fm->getFlowTable();

                std::ostringstream msg;
                msg << "Releasing " << name() << " cache";

                infoMessage(msg.str());

                uint64_t total_bytes_released_by_flows = 0;
                uint32_t release_flows = 0;

                for (auto &flow: ft) {
			if (SharedPointer<GPRSInfo> info = flow->getGPRSInfo(); info) {
                                flow->layer4info.reset();
                                total_bytes_released_by_flows += info->getIMSIString().size() + 16; // 16 bytes from the uint16_t
                                gprs_info_cache_->release(info);
                                ++release_flows;
                        }
                }

                std::string funit = "Bytes";

		data_time_ = boost::posix_time::microsec_clock::local_time();

                unitConverter(total_bytes_released_by_flows, funit);

                msg.str("");
                msg << "Release " << release_flows << " flows";
                msg << ", flow " << total_bytes_released_by_flows << " " << funit;
                infoMessage(msg.str());
        }
}

// Returns the flow's GPRSInfo (if it has one) to gprs_info_cache_.
// This function does not clear flow->layer4info itself.
void GPRSProtocol::releaseFlowInfo(Flow *flow) {

	SharedPointer<GPRSInfo> info = flow->getGPRSInfo();
	if (!info)
		return;

	gprs_info_cache_->release(info);
}

// Parses a GTP "Create PDP Context Request" and records findings in the flow's
// GPRSInfo. The GPRSInfo is acquired from gprs_info_cache_ on first use.
// Two things are recorded:
//  - the IMSI, when the presence byte is 0x02 (un.reg) or 0x01 (un.ext);
//  - the PDP type number from the End User Address element (0x80), when its
//    length field equals 2 (0x21 = IPv4, 0x57 = IPv6).
// The element walk is hand-rolled. It optionally skips Routing Area Identity
// (0x03) and Recovery (0x0E), requires 0x0F, and then optionally skips
// Charging Characteristics (0x1A).
//
// header_ must point at the first byte of flow->packet's payload; processFlow()
// sets it immediately before calling this. The payload is untrusted, so every
// read beyond the fixed-size headers is checked against the end of the packet.
// Parsing stops quietly as soon as a field would run past it.
void GPRSProtocol::process_create_pdp_context(Flow *flow) {

	// The fixed GPRS header plus the fixed Create PDP header must be present.
	const int32_t packet_length = flow->packet->getLength();

	if (packet_length < (int)sizeof(gprs_create_pdp_header) + (int)sizeof(gprs_header))
		return;

	SharedPointer<GPRSInfo> gprs_info = flow->getGPRSInfo();
	if (!gprs_info) {
		if (gprs_info = gprs_info_cache_->acquire(); gprs_info)
			flow->layer4info = gprs_info;
	}

	if (!gprs_info) // no GPRSInfo could be obtained from the cache
		return;

	// One past the last readable byte of this GPRS message.
	const uint8_t *const end = reinterpret_cast<const uint8_t*>(header_) + packet_length;

	// True when at least n bytes can be read starting at p.
	auto available = [end](const uint8_t *p, int n) {
		return (p <= end)and((end - p) >= n);
	};

	const gprs_create_pdp_header *cpd = reinterpret_cast<const gprs_create_pdp_header*>(header_->data);
	const uint8_t *extensions = &cpd->data[0];

	if (cpd->presence == 0x02) {
		gprs_info->setIMSI(cpd->un.reg.imsi);
		extensions = &cpd->un.reg.hdr[0];
	} else if (cpd->presence == 0x01) {
		// And extension header
		gprs_info->setIMSI(cpd->un.ext.imsi);
	}

	if (!available(extensions, 1))
		return;
	uint8_t token = extensions[0];

	if (token == 0x03) { // Routing Area Identity Header
		const gprs_create_pdp_header_routing *rhdr = reinterpret_cast<const gprs_create_pdp_header_routing*>(extensions);
		extensions = &rhdr->data[0];
		if (!available(extensions, 1))
			return;
		token = extensions[0];
	}

	if (token == 0x0E) { // Recovery: type byte + one value byte
		if (!available(extensions, 3))
			return;
		extensions = &extensions[2];
		token = extensions[0];
	}

	if (token != 0x0F)
		return;

	// Skip the 2-byte 0x0F element; the fixed extension header follows it.
	if (!available(extensions, 2))
		return;
	const gprs_create_pdp_header_ext *hext = reinterpret_cast<const gprs_create_pdp_header_ext*>(&extensions[2]);
	extensions = &hext->data[0];
	if (!available(extensions, 1))
		return;
	token = extensions[0];

	// After this step, extensions points just past the byte held in token.
	if (token == 0x1A) { // Charging Characteristics: type + 2 value bytes
		if (!available(extensions, 4))
			return;
		token = extensions[3];
		extensions = &extensions[4];
	} else {
		extensions = &extensions[1];
	}

	if (token == 0x80) { // End User Address: length(2) org(1) number(1)
		if (!available(extensions, 2))
			return;
		// The length field is big-endian on the wire. Decode it explicitly so
		// the result does not depend on host byte order.
		const uint16_t length = static_cast<uint16_t>((extensions[0] << 8) | extensions[1]);
		if ((length == 2)and(available(extensions, 4))) {
			// extensions[2] is the PDP type organisation (unused here)
			const uint8_t type_num = extensions[3];
			// type_num eq 0x21 is IPv4
			// type_num eq 0x57 is IPv6

			gprs_info->setPdpTypeNumber(type_num);
		}
	}
}

// Per-packet entry point for flows classified as GPRS.
// Always updates the packet/byte totals (and CPU-cycle accounting). The rest
// only happens when a multiplexer is attached and the payload holds the fixed
// header:
//  - a version-1 T-PDU is re-wrapped and forwarded to the multiplexer as
//    ip_protocol_type_, and any evidence flag is propagated back;
//  - a Create PDP Context Request is parsed into the flow's GPRSInfo;
//  - every other recognised message type only bumps its counter.
void GPRSProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);

	const int bytes = flow->packet->getLength();
	++total_packets_;
	total_bytes_ += bytes;

	if (mux_.expired() or (bytes < header_size))
		return;

	const uint8_t *payload = flow->packet->getPayload();
	setHeader(payload);

	const uint8_t msg_type = header_->type;
	const int8_t version = header_->flags >> 5;

	if ((msg_type == T_PDU)and(version == 1)) {
		MultiplexerPtr next_mux = mux_.lock();

		Packet inner(*(flow->packet));

		// Optional header fields shift the encapsulated payload.
		// (Original note: not sure if extension headers are seen on user data.)
		const int offset = (haveExtensionHeader() ? 6 : 0) + (haveSequenceNumber() ? 4 : 0);

		inner.setPayload(payload + offset);
		inner.setPrevHeaderSize(header_size + offset);

		next_mux->setNextProtocolIdentifier(ip_protocol_type_);
		next_mux->forward(inner);

		if (inner.haveEvidence())
			flow->packet->setEvidence(inner.haveEvidence());

		++total_tpdus_;
		return;
	}

	if (msg_type == CREATE_PDP_CONTEXT_REQUEST) {
		process_create_pdp_context(flow);
		++total_create_pdp_ctx_requests_;
	} else if (msg_type == CREATE_PDP_CONTEXT_RESPONSE) {
		++total_create_pdp_ctx_responses_;
	} else if (msg_type == UPDATE_PDP_CONTEXT_REQUEST) {
		++total_update_pdp_ctx_requests_;
	} else if (msg_type == UPDATE_PDP_CONTEXT_RESPONSE) {
		++total_update_pdp_ctx_responses_;
	} else if (msg_type == DELETE_PDP_CONTEXT_REQUEST) {
		// TODO shutdown the flow
		++total_delete_pdp_ctx_requests_;
	} else if (msg_type == DELETE_PDP_CONTEXT_RESPONSE) {
		// TODO shutdown the flow
		++total_delete_pdp_ctx_responses_;
	} else if (msg_type == GPRS_ECHO_REQUEST) {
		++total_echo_requests_;
	} else if (msg_type == GPRS_ECHO_RESPONSE) {
		++total_echo_responses_;
	}
}

// Writes human-readable statistics to `out`.
// level > 3: per-message-type counters and the GPRSInfo cache statistics.
// level > 5: also the attached multiplexer and flow forwarder, if still alive.
// `limit` is accepted for interface compatibility and not used here.
void GPRSProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total echo reqs:        " << std::setw(10) << total_echo_requests_ << "\n";
		out << "\t" << "Total echo ress:        " << std::setw(10) << total_echo_responses_ << "\n";
		out << "\t" << "Total create pdp reqs:  " << std::setw(10) << total_create_pdp_ctx_requests_ << "\n";
		out << "\t" << "Total create pdp ress:  " << std::setw(10) << total_create_pdp_ctx_responses_ << "\n";
		out << "\t" << "Total update pdp reqs:  " << std::setw(10) << total_update_pdp_ctx_requests_ << "\n";
		out << "\t" << "Total update pdp ress:  " << std::setw(10) << total_update_pdp_ctx_responses_ << "\n";
		out << "\t" << "Total delete pdp reqs:  " << std::setw(10) << total_delete_pdp_ctx_requests_ << "\n";
		out << "\t" << "Total delete pdp ress:  " << std::setw(10) << total_delete_pdp_ctx_responses_ << "\n";
		out << "\t" << "Total tpdus:          " << std::setw(12) << total_tpdus_ << std::endl;
	}

	if (level > 5) {
		if (auto next_mux = mux_.lock())
			next_mux->statistics(out);
		if (auto forwarder = flow_forwarder_.lock())
			forwarder->statistics(out);
	}

	if (detailed and gprs_info_cache_)
		gprs_info_cache_->statistics(out);
}

// JSON variant of statistics(): after the common header, emits the
// per-message-type counters when level > 3.
void GPRSProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["echo_requests"] = total_echo_requests_;
	out["echo_responses"] = total_echo_responses_;
	out["create_pdp_requests"] = total_create_pdp_ctx_requests_;
	out["create_pdp_responses"] = total_create_pdp_ctx_responses_;
	out["update_pdp_requests"] = total_update_pdp_ctx_requests_;
	out["update_pdp_responses"] = total_update_pdp_ctx_responses_;
	out["delete_pdp_requests"] = total_delete_pdp_ctx_requests_;
	out["delete_pdp_responses"] = total_delete_pdp_ctx_responses_;
	out["tpdus"] = total_tpdus_;
}

// Exports this protocol's counters as key/value pairs: packet/byte totals
// followed by one entry per GPRS message type.
CounterMap GPRSProtocol::getCounters() const {
	CounterMap cm;

	cm.addKeyValue("packets", total_packets_);
	cm.addKeyValue("bytes", total_bytes_);
	cm.addKeyValue("echo reqs", total_echo_requests_);
	// Previously reported total_echo_requests_ under this key by mistake.
	cm.addKeyValue("echo ress", total_echo_responses_);
	cm.addKeyValue("create pdp reqs", total_create_pdp_ctx_requests_);
	cm.addKeyValue("create pdp ress", total_create_pdp_ctx_responses_);
	cm.addKeyValue("update pdp reqs", total_update_pdp_ctx_requests_);
	cm.addKeyValue("update pdp ress", total_update_pdp_ctx_responses_);
	cm.addKeyValue("delete pdp reqs", total_delete_pdp_ctx_requests_);
	cm.addKeyValue("delete pdp ress", total_delete_pdp_ctx_responses_);
	cm.addKeyValue("tpdus", total_tpdus_);

	return cm;
}

// Zeroes every GPRS-specific counter after calling the base reset(), which
// presumably clears the generic packet/byte counters (not visible here —
// confirm in Protocol).
void GPRSProtocol::resetCounters() {

	reset();

	// Echo
	total_echo_requests_ = 0;
	total_echo_responses_ = 0;

	// PDP context lifecycle
	total_create_pdp_ctx_requests_ = 0;
	total_create_pdp_ctx_responses_ = 0;
	total_update_pdp_ctx_requests_ = 0;
	total_update_pdp_ctx_responses_ = 0;
	total_delete_pdp_ctx_requests_ = 0;
	total_delete_pdp_ctx_responses_ = 0;

	// User-plane data
	total_tpdus_ = 0;
}

} // namespace aiengine

